LMCorrector
The LMCorrector class is a language model-based corrector that uses beam search with distortion probabilities to correct text input. It can be used to fix errors in text based on a pretrained language model.
Key Features
Utilizes beam search with distortion probabilities
Supports various pretrained language models
Configurable parameters for fine-tuning correction behavior
Supports both batch and streaming correction modes
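The core idea behind these features (a beam search whose hypothesis scores combine language-model probability with a distortion probability, i.e. how likely the observed character is a corruption of the hypothesized one) can be sketched in a few lines of pure Python. Everything below (the toy bigram LM, the confusion set, and the penalty values) is an illustrative assumption, not the library's implementation:

```python
# Toy bigram "language model" scores; a real corrector queries a pretrained LM.
def lm_logp(prev, ch):
    table = {("<s>", "c"): -0.1, ("c", "a"): -0.2, ("a", "t"): -0.3,
             ("<s>", "k"): -3.0, ("k", "a"): -2.5}
    return table.get((prev, ch), -5.0)

# Distortion log-probabilities: keeping the observed character is cheap,
# substituting it is penalized (values are illustrative assumptions).
KEEP_LOGP = -0.1
SUB_LOGP = -2.0

# Candidate replacements per observed character (assumed to come from a
# confusion set in a real system).
CONFUSIONS = {"k": {"c"}}

def correct(observed, n_beam=2):
    """Beam search over per-character keep/substitute decisions."""
    beams = [("", "<s>", 0.0)]  # (output so far, previous char, log score)
    for obs in observed:
        candidates = []
        for out, prev, score in beams:
            options = {obs: KEEP_LOGP}
            for alt in CONFUSIONS.get(obs, ()):
                options[alt] = SUB_LOGP
            for ch, dist_logp in options.items():
                candidates.append(
                    (out + ch, ch, score + dist_logp + lm_logp(prev, ch)))
        beams = sorted(candidates, key=lambda b: b[2], reverse=True)[:n_beam]
    return beams[0][0]
```

Here the LM's preference for "cat" outweighs the substitution penalty, so the corrupted input "kat" is corrected while a clean input passes through unchanged.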
Usage
Here’s a basic example of how to use the LMCorrector:
from lmcsc.corrector import LMCorrector

# Initialize the corrector with a pretrained model
corrector = LMCorrector("gpt2")

# Correct a single sentence
result = corrector("完善农产品上行发展机智。")
print(result)  # Output: [('完善农产品上行发展机制。',)]

# Correct multiple sentences
results = corrector(["完善农产品上行发展机智。", "这是一个测试句子。"])
print(results)

# Use streaming mode
for output in corrector("完善农产品上行发展机智。", stream=True):
    print(output)
Configuration
The LMCorrector can be configured using a YAML configuration file. The default configuration file is located at configs/default_config.yaml. You can specify a custom configuration file path when initializing the corrector:
corrector = LMCorrector("gpt2", config_path="path/to/custom_config.yaml")
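A custom configuration file presumably overrides the same settings the constructor exposes. The keys below are assumptions that mirror the constructor's parameter names, not the verified schema; check configs/default_config.yaml in the repository for the actual structure:

```yaml
# Hypothetical custom_config.yaml; key names mirror the constructor
# parameters and are assumptions, not the documented schema.
n_beam: 16
n_beam_hyps_to_keep: 1
alpha: 0.1
distortion_model_smoothing: 0.01
use_faithfulness_reward: true
max_length: 128
```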
Advanced Usage
The LMCorrector class provides several advanced features and customization options:
Custom distortion probabilities
Faithfulness reward
Beam search parameters adjustment
Context-aware correction
For more details on these advanced features, please refer to the class documentation.
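As an illustration of one of these features: a faithfulness reward gives a bonus to hypotheses that stay close to the observed source, counteracting the LM's tendency to rewrite fluent but unfaithful text. The sketch below is a hedged guess at the general shape (alpha as a length reward, matching the parameter list in the API documentation; the function name, reward form, and values are assumptions):

```python
def beam_score(lm_logp, hypothesis, source, alpha=0.1, reward=0.5):
    """Score a beam hypothesis: LM log-probability plus a length reward
    (alpha) and a faithfulness bonus per character kept from the source.
    Illustrative only; not the library's actual scoring function."""
    matches = sum(h == s for h, s in zip(hypothesis, source))
    return lm_logp + alpha * len(hypothesis) + reward * matches
```

With equal LM scores, the hypothesis that preserves more source characters scores higher, which is the intended effect of the reward.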
API Documentation
- class lmcsc.corrector.LMCorrector(model: str | LMModel, prompted_model: str | LMModel | None = None, config_path: str = 'configs/default_config.yaml', n_observed_chars: int | None = None, n_beam: int | None = None, n_beam_hyps_to_keep: int | None = None, alpha: float | None = None, temperature: float | None = None, distortion_model_smoothing: float | None = None, use_faithfulness_reward: bool | None = None, customized_distortion_probs: dict | None = None, max_length: int | None = None, use_chat_prompted_model: bool = False, *args, **kwargs)[source]
Bases: object
A language model-based corrector that utilizes beam search with distortion probabilities to correct text input. The corrector can be used to fix errors in text based on a pretrained language model.
- Parameters:
model (Union[str, LMModel]) – The pretrained language model or a string identifier of the model.
prompted_model (Union[str, LMModel], optional) – The prompt-based language model or a string identifier of it. Defaults to None.
config_path (str, optional) – Path to the configuration file. Defaults to ‘configs/default_config.yaml’.
n_observed_chars (int, optional) – Number of observed characters for the input. Defaults to None.
n_beam (int, optional) – Number of beams for beam search. Defaults to None.
n_beam_hyps_to_keep (int, optional) – Number of beam hypotheses to keep. Defaults to None.
alpha (float, optional) – Hyperparameter for the length reward during beam search. Defaults to None.
temperature (float, optional) – Temperature for the prompt-based LLM. Defaults to None.
distortion_model_smoothing (float, optional) – Smoothing factor for distortion model probabilities. Defaults to None.
use_faithfulness_reward (bool, optional) – Whether to use faithfulness reward in beam search. Defaults to None.
customized_distortion_probs (dict, optional) – Custom distortion probabilities for different transformation types. Defaults to None.
max_length (int, optional) – Maximum allowed length for the input. Defaults to None.
use_chat_prompted_model (bool, optional) – Whether to use a chat-style prompted model. Defaults to False.
*args – Variable length argument list.
**kwargs – Arbitrary keyword arguments.
Note
A parameter left as None falls back to the corresponding value in the configuration file.
- decorate_model_instance() → None[source]
Decorates the model instance by setting necessary attributes, adding methods, and configuring distortion probabilities and transformation types.
- update_params(**kwargs)[source]
Updates the parameters of the model and the corrector.
- Parameters:
**kwargs – Arbitrary keyword arguments corresponding to the parameters to update.
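A plausible sketch of the pattern update_params follows; this is an assumption about its shape (the real method also updates the underlying model, which is omitted here), with a hypothetical Configurable class standing in for the corrector:

```python
class Configurable:
    """Stand-in for an object with tunable parameters (hypothetical)."""
    def __init__(self):
        self.n_beam = 8
        self.alpha = 0.1

    def update_params(self, **kwargs):
        # Set each keyword argument as an attribute on the instance.
        for name, value in kwargs.items():
            setattr(self, name, value)

c = Configurable()
c.update_params(n_beam=16, alpha=0.2)
```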
- preprocess(src: List[str] | str, contexts: List[str] | str | None = None)[source]
Preprocesses the source text by cleaning and truncating.
- Parameters:
src (Union[List[str], str]) – Source text or list of source texts.
contexts (Union[List[str], str], optional) – Additional context texts. Defaults to None.
- Returns:
Cleaned source texts and lists of changes made during cleaning.
- Return type:
Tuple[List[str], List[List[Tuple[int, int, str, str]]]]
- postprocess(preds: List[str], ori_srcs: List[str], changes: List[List[Tuple[int, int, str, str]]], append_src_left_over: bool = True)[source]
Postprocesses the predictions to rebuild the sentences and handle out-of-vocabulary characters.
- Parameters:
preds (List[str]) – List of predicted texts.
ori_srcs (List[str]) – List of original source texts.
changes (List[List[Tuple[int, int, str, str]]]) – Lists of changes made during preprocessing.
append_src_left_over (bool, optional) – Whether to append leftover source text after the max length. Defaults to True.
- Returns:
Postprocessed predictions.
- Return type:
List[List[str]]
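The change records passed between preprocess and postprocess have the type Tuple[int, int, str, str], which suggests (start, end, original, replacement) spans, though their exact semantics are not documented here. The sketch below illustrates that general record-and-restore pattern with a single hypothetical cleaning rule (full-width space normalization); it is not the library's implementation:

```python
def clean_with_changes(src):
    """Replace full-width spaces and record (start, end, original, replacement)
    tuples so the edits can be undone later. The rule is hypothetical."""
    changes = []
    out = []
    for i, ch in enumerate(src):
        if ch == "\u3000":  # full-width space -> ASCII space
            changes.append((i, i + 1, ch, " "))
            out.append(" ")
        else:
            out.append(ch)
    return "".join(out), changes

def restore(pred, changes):
    """Undo recorded replacements, assuming positions are unchanged."""
    chars = list(pred)
    for start, end, original, replacement in changes:
        if chars[start:end] == list(replacement):
            chars[start:end] = list(original)
    return "".join(chars)
```

Recording the edits during cleaning is what lets postprocessing rebuild the sentence around characters the model never saw (such as out-of-vocabulary symbols).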